{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No module named 'utils'\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Prefer the project-local helper; fall back to the Keras preprocessing\n",
    "# when `utils` is missing or does not provide `preprocess_img`.\n",
    "# Either way `preprocess_img` is bound after this cell, which the\n",
    "# ImageDataGenerator cell further down relies on.\n",
    "try:\n",
    "    import utils\n",
    "    # NOTE(review): assumes utils exposes preprocess_img -- confirm.\n",
    "    preprocess_img = utils.preprocess_img\n",
    "except (ImportError, AttributeError) as e:\n",
    "    print(e)\n",
    "    from tensorflow.keras.applications.vgg16 import preprocess_input\n",
    "\n",
    "    def preprocess_img(x):\n",
    "        \"\"\"Return image array `x` preprocessed with Keras' 'tf' mode.\n",
    "\n",
    "        Keras documents this mode as scaling pixel values to [-1, 1].\n",
    "        \"\"\"\n",
    "        return preprocess_input(x, mode='tf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from glob import glob\n",
    "import tensorflow as tf\n",
    "import keras\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from keras.applications.inception_v3 import InceptionV3\n",
    "from keras.applications.xception import Xception\n",
    "from keras.applications.resnet50 import ResNet50\n",
    "from keras.models import Model, Sequential, Input\n",
    "from keras.layers import *\n",
    "from keras.optimizers import Adam,SGD\n",
    "from keras.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import os\n",
    "import PIL\n",
    "import time\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset locations\n",
    "path_data = '../garbage_classify/train_data'\n",
    "path_data_train, path_data_valid = '../tmp/data_train/', '../tmp/data_valid/'\n",
    "labels_file = '../tmp/labels_raw.csv'\n",
    "\n",
    "# Input geometry: square images, one edge length shared by all three names\n",
    "img_size = img_width = img_height = 224\n",
    "\n",
    "# Batching and reproducibility\n",
    "batch_size = 32\n",
    "random_seed = 201908"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x7f33c6638978>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAEGCAYAAACJnEVTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAW6UlEQVR4nO3dfZBV9X3H8fdHWJaChidXxwgW7BAE4hNeiSapwTwIaFtiYhJSOyHaDDFBm2SmNthkah4mnclTWzER3RpNMiVBQ4PZVCNgakzaEeESCcqTbJEMS1RWiWhIFNBv/7i/hQvZu3tXdvfend/nNXPnnvM7v3P5njPD5579nXPPUURgZmZ5OK7WBZiZWf9x6JuZZcShb2aWEYe+mVlGHPpmZhkZXOsCunLiiSfG+PHja12GmdmAsm7dumcjoqmzZXUd+uPHj6dYLNa6DDOzAUXSryst8/COmVlGHPpmZhlx6JuZZaSux/TNzI524MAB2traeOmll2pdSs0NHTqUsWPH0tDQUPU6Dn0zG1Da2to44YQTGD9+PJJqXU7NRATPPfccbW1tTJgwoer1qhrekTRS0jJJWyRtlnShpC9K2iBpvaSVkl6f+krSIkmtafm0ss+ZJ2lbes3r8VaaWfZeeuklxowZk3XgA0hizJgxPf6Lp9ox/ZuA+yPiDOBsYDPw1Yg4KyLOAf4L+KfUdzYwMb3mA4tTgaOBG4E3AdOBGyWN6lG1ZmaQfeB3eC37odvQlzQCuAj4FkBE7I+I5yPihbJuw4GOezTPAb4bJauBkZJOAWYCqyJiT0T8FlgFzOpxxWZm9ppVc6Q/AWgH7pT0qKTbJQ0HkPQlSTuBKzl8pH8qsLNs/bbUVqn9CJLmSypKKra3t/d4g8zM+sOiRYuYPHkyV155Za1L6ZFqQn8wMA1YHBHnAvuAhQAR8ZmIGAcsAa7tjYIiojkiChFRaGrq9FfEZmY1d8stt7Bq1SqWLFlS61J6pJrQbwPaIuKRNL+M0pdAuSXAe9P0LmBc2bKxqa1Su5nZgHLNNdewfft2Zs+ezYgRI7j66quZMWMGp59+OosWLTrU793vfjfnnXceU6dOpbm5+VD78ccfz/XXX8/UqVN55zvfyZo1aw6t39LSAsArr7zC9ddfz/nnn89ZZ53Fbbfd1iu1q5rHJUr6BfCRiNgq6XOUxvCbI2JbWn4d8LaIuELSZZSO+i+ldNJ2UURMTydy13H4C+OXwHkRsafSv1soFML33jGzcps3b2by5MkAfP7HG9n0mxe6WaNnprz+ddz4l1O77ddxb7BvfOMbrFy5kgcffJAXX3yRSZMm8fTTT9PQ0MCePXsYPXo0f/jDHzj//PN56KGHDl15dN999zF79mwuv/xy9u3bx7333sumTZuYN28e69evp7m5md27d/PZz36Wl19+mbe85S384Ac/+KPLM8v3RwdJ6yKi0Fnd1V6nfx2wRNIQYDtwFXC7pEnAq8CvgWtS3/soBX4r8PvUl4jYI+mLwNrU7wtdBb6Z2UBx2WWX0djYSGNjIyeddBLPPPMMY8eOZdGiRSxfvhyAnTt3sm3bNsaMGcOQIUOYNat0HcuZZ55JY2MjDQ0NnHnmmezYsQOAlStXsmHDBpYtWwbA3r172bZtW4+uye9MVaEfEeuBo7813luhbwALKiy7A7ijJwWamVVSzRF5f2hsbDw0PWjQIA4ePMjPfvYzHnjgAR5++GGGDRvGjBkzDl1T39DQcOhyy+OOO+7Q+scddxwHDx4ESj++uvnmm5k5c2av1up775iZ9YG9e/cyatQohg0bxpYtW1i9enWP1p85cyaLFy/mwIEDADzxxBPs27fvmOvybRjMzPrArFmzuPXWW5k8eTKTJk3iggsu6NH6H/nIR9ixYwfTpk0jImhqauKee+455rqqOpFbKz6Ra2ZH6+zEZc56eiLXwztmZhlx6JuZZcShb2YD
Tj0PS/en17IfHPpmNqAMHTqU5557Lvvg77if/tChQ3u0nq/eMbMBZezYsbS1teEbMh5+clZPOPTNbEBpaGg45l+l5szDO2ZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpYRh76ZWUYc+mZmGXHom5llxKFvZpaRqkJf0khJyyRtkbRZ0oWSvprmN0haLmlkWf8bJLVK2ippZln7rNTWKmlhX2yQmZlVVu2R/k3A/RFxBnA2sBlYBbwxIs4CngBuAJA0BZgLTAVmAbdIGiRpEPBNYDYwBfhg6mtmZv2k29CXNAK4CPgWQETsj4jnI2JlRBxM3VYDHc/smgMsjYiXI+JJoBWYnl6tEbE9IvYDS1NfMzPrJ9Uc6U8A2oE7JT0q6XZJw4/qczXwkzR9KrCzbFlbaqvUfgRJ8yUVJRX9DEwzs95VTegPBqYBiyPiXGAfcGg8XtJngIPAkt4oKCKaI6IQEYWmpqbe+EgzM0uqCf02oC0iHknzyyh9CSDpw8BfAFdGRKTlu4BxZeuPTW2V2s3MrJ90G/oR8TSwU9Kk1PQOYJOkWcA/AH8VEb8vW6UFmCupUdIEYCKwBlgLTJQ0QdIQSid7W3pxW8zMrBuDq+x3HbAkhfV24CpKId4IrJIEsDoiromIjZLuBjZRGvZZEBGvAEi6FlgBDALuiIiNvbo1ZmbWJR0elak/hUIhisVircswMxtQJK2LiEJny/yLXDOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCMOfTOzjDj0zcwy4tA3M8uIQ9/MLCNVhb6kkZKWSdoiabOkCyW9T9JGSa9KKhzV/wZJrZK2SppZ1j4rtbVKWtjbG2NmZl0bXGW/m4D7I+IKSUOAYcDzwHuA28o7SpoCzAWmAq8HHpD0hrT4m8C7gDZgraSWiNh07JthZmbV6Db0JY0ALgI+DBAR+4H9lEIfSUevMgdYGhEvA09KagWmp2WtEbE9rbc09XXom5n1k2qGdyYA7cCdkh6VdLuk4V30PxXYWTbfltoqtR9B0nxJRUnF9vb2KsozM7NqVRP6g4FpwOKIOBfYB/TZeHxENEdEISIKTU1NffXPmJllqZrQbwPaIuKRNL+M0pdAJbuAcWXzY1NbpXYzM+sn3YZ+RDwN7JQ0KTW9g67H4VuAuZIaJU0AJgJrgLXAREkT0snguamvmZn1k2qv3rkOWJLCejtwlaTLgZuBJuBeSesjYmZEbJR0N6UvhoPAgoh4BUDStcAKYBBwR0Rs7OXtMTOzLigial1DRYVCIYrFYq3LMDMbUCSti4hCZ8v8i1wzs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjDn0zs4w49M3MMuLQNzPLiEPfzCwjVYW+pJGSlknaImmzpAsljZa0StK29D4q9ZWkRZJaJW2QNK3sc+al/tskzeurjTIzs85Ve6R/E3B/RJwBnA1sBhYCP42IicBP0zzAbGBies0HFgNIGg3cCLwJmA7c2PFFYWZm/WNwdx0kjQAuAj4MEBH7gf2S5gAzUrfvAD8DPg3MAb4bEQGsTn8lnJL6roqIPelzVwGzgO9X+re3t+/jA7c9/Fq2y8zMOlHNkf4EoB24U9Kjkm6XNBw4OSKeSn2eBk5O06cCO8vWb0ttldqPIGm+pKKk4oEDB3q2NWZm1qVuj/RTn2nAdRHxiKSbODyUA0BEhKTojYIiohloBigUCnHXRy/sjY81M8vG3ddUXlbNkX4b0BYRj6T5ZZS+BJ5Jwzak991p
+S5gXNn6Y1NbpXYzM+sn3YZ+RDwN7JQ0KTW9A9gEtAAdV+DMA36UpluAD6WreC4A9qZhoBXAJZJGpRO4l6Q2MzPrJ9UM7wBcByyRNATYDlxF6Qvjbkl/C/waeH/qex9wKdAK/D71JSL2SPoisDb1+0LHSV0zM+sfKl1kU58KhUIUi8Val2FmNqBIWhcRhc6W+Re5ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRhz6ZmYZceibmWXEoW9mlhGHvplZRqoKfUk7JD0mab2kYmo7W9LDqf3Hkl5X1v8GSa2StkqaWdY+K7W1SlrY+5tjZmZd6cmR/sURcU5EFNL87cDCiDgTWA5cDyBpCjAXmArMAm6RNEjSIOCbwGxgCvDB1NfMzPrJsQzvvAH4eZpeBbw3Tc8BlkbEyxHxJNAKTE+v1ojYHhH7gaWpr5mZ9ZNqQz+AlZLWSZqf2jZyOLTfB4xL06cCO8vWbUttldqPIGm+pKKkYnt7e5XlmZlZNaoN/bdGxDRKQzMLJF0EXA18XNI64ARgf28UFBHNEVGIiEJTU1NvfKSZmSWDq+kUEbvS+25Jy4HpEfE14BIASW8ALkvdd3H4qB9gbGqji3YzM+sH3R7pSxou6YSOaUpB/7ikk1LbccBngVvTKi3AXEmNkiYAE4E1wFpgoqQJkoZQOtnb0tsbZGZmlVVzpH8ysFxSR//vRcT9kj4haUHq80PgToCI2CjpbmATcBBYEBGvAEi6FlgBDALuiIiNvbo1ZmbWJUVErWuoqFAoRLFYrHUZZmYDiqR1ZZfXH8G/yDUzy4hD38wsIw59M7OMOPTNzDLi0Dczy4hD38wsIw59M7OMOPTNzDLi0Dczy4hD38wsIw59M7OMOPTNzDLi0Dczy4hD38wsIw59M7OMOPTNzDLi0Dczy4hD38wsIw59M7OMOPTNzDLi0Dczy4hD38wsIw59M7OMOPTNzDJSVehL2iHpMUnrJRVT2zmSVne0SZqe2iVpkaRWSRskTSv7nHmStqXXvL7ZJDMzq2RwD/peHBHPls1/Bfh8RPxE0qVpfgYwG5iYXm8CFgNvkjQauBEoAAGsk9QSEb899s0wM7NqHMvwTgCvS9MjgN+k6TnAd6NkNTBS0inATGBVROxJQb8KmHUM/76ZmfVQtUf6AayUFMBtEdEMfBJYIelrlL483pz6ngrsLFu3LbVVaj+CpPnAfIDTTjut+i0xM7NuVXuk/9aImEZp6GaBpIuAjwGfiohxwKeAb/VGQRHRHBGFiCg0NTX1xkeamVlSVehHxK70vhtYDkwH5gE/TF1+kNoAdgHjylYfm9oqtZuZWT/pNvQlDZd0Qsc0cAnwOKUx/Lelbm8HtqXpFuBD6SqeC4C9EfEUsAK4RNIoSaPS56zo1a0xM7MuVTOmfzKwXFJH/+9FxP2SfgfcJGkw8BJpHB64D7gUaAV+D1wFEBF7JH0RWJv6fSEi9vTalpiZWbcUEbWuoaJCoRDFYrHWZZiZDSiS1kVEobNl/kWumVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWEYe+mVlGHPpmZhlx6JuZZcShb2aWkapCX9IOSY9JWi+pmNruSvPr0/L1Zf1vkNQqaaukmWXts1Jbq6SFvb85ZmbWlcE96HtxRDzbMRMRH+iYlvR1YG+angLMBaYCrwcekPSG1PWbwLuANmCtpJaI2HRsm2BmZtXqSeh3SpKA9wNvT01zgKUR8TLwpKRWYHpa1hoR29N6S1Nfh76ZWT+pdkw/gJWS1kmaf9SyPweeiYhtaf5UYGfZ
8rbUVqn9CJLmSypKKra3t1dZnpmZVaPa0H9rREwDZgMLJF1UtuyDwPd7q6CIaI6IQkQUmpqaeutjzcyMKkM/Inal993ActJwjaTBwHuAu8q67wLGlc2PTW2V2s3MrJ90G/qShks6oWMauAR4PC1+J7AlItrKVmkB5kpqlDQBmAisAdYCEyVNkDSE0snelt7bFDMz6041J3JPBpaXztcyGPheRNyfls3lqKGdiNgo6W5KJ2gPAgsi4hUASdcCK4BBwB0RsbFXtsLMzKqiiKh1DRUVCoUoFou1LsPMbECRtC4iCp0t8y9yzcwy4tA3M8uIQ9/MLCMOfTOzjNT1iVxJLwJba11HlU4Enu22V31wrX3DtfYN19pzfxoRnf669ZjvvdPHtlY6A11vJBVda+9zrX3DtfaNgVCrh3fMzDLi0Dczy0i9h35zrQvoAdfaN1xr33CtfaPua63rE7lmZta76v1I38zMepFD38wsI3Ub+gPpIeqdPTi+Xki6Q9JuSY+XtY2WtErStvQ+qpY1dqhQ6+ck7Ur7dr2kS2tZYwdJ4yQ9KGmTpI2SPpHa62rfdlFn3e1XSUMlrZH0q1Tr51P7BEmPpCy4K92avV5r/bakJ8v26zm1rvWPRETdvSjdevn/gNOBIcCvgCm1rquLencAJ9a6jgq1XQRMAx4va/sKsDBNLwS+XOs6u6j1c8Df17q2Tmo9BZiWpk8AngCm1Nu+7aLOutuvgIDj03QD8AhwAXA3MDe13wp8rI5r/TZwRa3r6+pVr0f600kPUY+I/UDHQ9SthyLi58Ceo5rnAN9J098B3t2vRVVQoda6FBFPRcQv0/SLwGZKz3yuq33bRZ11J0p+l2Yb0iuAtwPLUnvN9yl0WWvdq9fQr+oh6nWkqwfH16OTI+KpNP00pQfl1LNrJW1Iwz91MRRVTtJ44FxKR3t1u2+PqhPqcL9KGiRpPbAbWEXpL/7nI+Jg6lI3WXB0rRHRsV+/lPbrv0pqrGGJnarX0B9ounpwfF2L0t+n9XyEshj4M+Ac4Cng67Ut50iSjgf+E/hkRLxQvqye9m0nddblfo2IVyLiHErP0J4OnFHjkio6ulZJbwRuoFTz+cBo4NM1LLFT9Rr6A+oh6lHhwfF17BlJpwCk9901rqeiiHgm/ed6Ffh36mjfSmqgFKRLIuKHqbnu9m1nddbzfgWIiOeBB4ELgZGSOu4TVndZUFbrrDScFhHxMnAndbZfoX5Df8A8RL2bB8fXqxZgXpqeB/yohrV0qSNAk8upk32r0kOjvwVsjoh/KVtUV/u2Up31uF8lNUkamab/BHgXpXMQDwJXpG4136dQsdYtZV/4onTuoeb79Wh1+4vcdAnZv3H4IepfqnFJnZJ0OqWjezj84Pi6qVXS94EZlG75+gxwI3APpSsiTgN+Dbw/Imp+ArVCrTMoDUEEpaukPlo2Zl4zkt4K/AJ4DHg1Nf8jpfHyutm3XdT5Qepsv0o6i9KJ2kGUDkjvjogvpP9jSykNlzwK/E06kq6ZLmr9b6CJ0tU964Fryk741oW6DX0zM+t99Tq8Y2ZmfcChb2aWEYe+mVlGHPpmZhlx6JuZZcShb1ZGUpeX10kaX34X0Co/89uSrui+p1nfc+ibmWXEoW/WCUnHS/qppF+mZyWU3+V1sKQlkjZLWiZpWFrnPEkPpRvvrTjqV69mdcGhb9a5l4DL0430Lga+nn5aDzAJuCUiJgMvAB9P97e5mdK91M8D7gDq5pfZZh0Gd9/FLEsC/jndMfVVSrfz7bhN8s6I+N80/R/A3wH3A28EVqXvhkGU7l5pVlcc+madu5LSPVTOi4gDknYAQ9Oyo+9dEpS+JDZGxIX9V6JZz3l4x6xzI4DdKfAvBv60bNlpkjrC/a+B/wG2Ak0d7ZIaJE3t14rNquDQN+vcEqAg6THgQ8CWsmVbKT0sZzMwClicHut5BfBlSb+idIfFN/dzzWbd8l02zcwy4iN9M7OMOPTNzDLi0Dczy4hD
38wsIw59M7OMOPTNzDLi0Dczy8j/A5F85l1ULq6gAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Label tables for the two splits; the generator cell reads their\n",
    "# `fname` and `label` columns.\n",
    "labels_valid = pd.read_csv('../tmp/labels_valid.csv')\n",
    "labels_train = pd.read_csv('../tmp/labels_train.csv')\n",
    "\n",
    "# Number of distinct training labels; this sizes the softmax layer of the model.\n",
    "n_classess = len(labels_train['label'].unique())\n",
    "\n",
    "# Row count per label, to eyeball the class balance.\n",
    "labels_train.groupby('label').count().plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the integer class ids into zero-padded two-character strings (3 -> '03').\n",
    "# NOTE(review): presumably so that the string-sorted class order used by\n",
    "# flow_from_dataframe matches the numeric order -- confirm.\n",
    "# Casting to int first keeps this cell idempotent: formatting with the 'd'\n",
    "# code raises ValueError on values that are already strings, which is what\n",
    "# the column holds after the first run.\n",
    "labels_train['label'] = labels_train['label'].astype(int).map('{:02d}'.format)\n",
    "labels_valid['label'] = labels_valid['label'].astype(int).map('{:02d}'.format)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 240000 validated image filenames belonging to 40 classes.\n",
      "Found 2978 validated image filenames belonging to 40 classes.\n"
     ]
    }
   ],
   "source": [
    "# One generator serves both splits; the only transform configured here is\n",
    "# preprocess_img (no augmentation options are set).\n",
    "ig = ImageDataGenerator(preprocessing_function=preprocess_img)\n",
    "\n",
    "# Keyword arguments shared by the train and validation iterators.\n",
    "# The image directory is passed per split below; class_mode is left at the\n",
    "# Keras default.\n",
    "# NOTE(review): Keras documents target_size as (height, width); the order\n",
    "# used here is harmless while both values are 224 -- confirm if that changes.\n",
    "params_g = {\n",
    "    'batch_size': batch_size,\n",
    "    'x_col': 'fname',\n",
    "    'y_col': 'label',\n",
    "    'target_size': (img_width, img_height),\n",
    "    'seed': random_seed,\n",
    "}\n",
    "\n",
    "train_g = ig.flow_from_dataframe(\n",
    "    dataframe=labels_train, directory=path_data_train, **params_g)\n",
    "valid_g = ig.flow_from_dataframe(\n",
    "    dataframe=labels_valid, directory=path_data_valid, **params_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "# Backbone: ImageNet-pretrained Xception without its classification head.\n",
    "base_model = Xception(\n",
    "    include_top=False,\n",
    "    weights='imagenet',\n",
    "    input_shape=(img_width, img_height, 3))\n",
    "\n",
    "# Head: global average pooling -> 128-unit ReLU layer -> softmax with one\n",
    "# output per class (n_classess).\n",
    "# Alternatives left disabled: a 512-unit dense layer and Dropout(0.5).\n",
    "x = GlobalAveragePooling2D()(base_model.output)\n",
    "x = Dense(128, activation='relu')(x)\n",
    "predictions = Dense(n_classess, activation='softmax')(x)\n",
    "\n",
    "# Full model to train. The loop that would freeze the backbone layers is\n",
    "# disabled, so those layers keep their default trainable state.\n",
    "model = Model(inputs=base_model.input, outputs=predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint files are tagged with the visible GPU id and the time this\n",
    "# cell was run; Keras fills in the {epoch}/{acc}/{val_acc} placeholders\n",
    "# when it saves. Only epochs that improve val_acc are written.\n",
    "ckpt = ModelCheckpoint(\n",
    "    filepath=(\n",
    "        '../tmp/ckpt-GPU' + os.environ[\"CUDA_VISIBLE_DEVICES\"]\n",
    "        + '-' + time.strftime('%Y-%m-%d_%H_%M')\n",
    "        + '-Epoch_{epoch:03d}-acc_{acc:.5f}-val_acc_{val_acc:.5f}.h5'),\n",
    "    monitor='val_acc',\n",
    "    save_best_only=True)\n",
    "\n",
    "# Stop once val_acc has gone 20 epochs without improving by at least 1e-7.\n",
    "estop = EarlyStopping(\n",
    "    monitor='val_acc', min_delta=1e-7, patience=20, verbose=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "Epoch 1/100\n",
      "7500/7500 [==============================] - 2903s 387ms/step - loss: 1.0540 - acc: 0.7265 - val_loss: 0.4482 - val_acc: 0.8656\n",
      "Epoch 2/100\n",
      "7500/7500 [==============================] - 2926s 390ms/step - loss: 0.2737 - acc: 0.9213 - val_loss: 0.4095 - val_acc: 0.8829\n",
      "Epoch 3/100\n",
      "7500/7500 [==============================] - 2927s 390ms/step - loss: 0.1065 - acc: 0.9715 - val_loss: 0.4320 - val_acc: 0.8832\n",
      "Epoch 4/100\n",
      "7500/7500 [==============================] - 2917s 389ms/step - loss: 0.0428 - acc: 0.9897 - val_loss: 0.4518 - val_acc: 0.8856\n",
      "Epoch 5/100\n",
      "7500/7500 [==============================] - 2914s 389ms/step - loss: 0.0204 - acc: 0.9955 - val_loss: 0.4878 - val_acc: 0.8809\n",
      "Epoch 6/100\n",
      "7500/7500 [==============================] - 2912s 388ms/step - loss: 0.0119 - acc: 0.9975 - val_loss: 0.5312 - val_acc: 0.8836\n",
      "Epoch 7/100\n",
      "7500/7500 [==============================] - 2903s 387ms/step - loss: 0.0086 - acc: 0.9981 - val_loss: 0.5358 - val_acc: 0.8849\n",
      "Epoch 8/100\n",
      "7500/7500 [==============================] - 2900s 387ms/step - loss: 0.0061 - acc: 0.9987 - val_loss: 0.5541 - val_acc: 0.8822\n",
      "Epoch 9/100\n",
      "7500/7500 [==============================] - 2899s 387ms/step - loss: 0.0052 - acc: 0.9988 - val_loss: 0.5633 - val_acc: 0.8873\n",
      "Epoch 10/100\n",
      "7500/7500 [==============================] - 2897s 386ms/step - loss: 0.0042 - acc: 0.9990 - val_loss: 0.5980 - val_acc: 0.8781\n",
      "Epoch 11/100\n",
      "7500/7500 [==============================] - 2899s 387ms/step - loss: 0.0036 - acc: 0.9992 - val_loss: 0.5981 - val_acc: 0.8785\n",
      "Epoch 12/100\n",
      "7500/7500 [==============================] - 2900s 387ms/step - loss: 0.0033 - acc: 0.9992 - val_loss: 0.6119 - val_acc: 0.8839\n",
      "Epoch 13/100\n",
      "7500/7500 [==============================] - 2900s 387ms/step - loss: 0.0027 - acc: 0.9993 - val_loss: 0.6049 - val_acc: 0.8846\n",
      "Epoch 14/100\n",
      "7500/7500 [==============================] - 2898s 386ms/step - loss: 0.0026 - acc: 0.9993 - val_loss: 0.6021 - val_acc: 0.8836\n",
      "Epoch 15/100\n",
      "7500/7500 [==============================] - 2893s 386ms/step - loss: 0.0024 - acc: 0.9994 - val_loss: 0.6379 - val_acc: 0.8809\n",
      "Epoch 16/100\n",
      "7500/7500 [==============================] - 2880s 384ms/step - loss: 0.0022 - acc: 0.9995 - val_loss: 0.6917 - val_acc: 0.8788\n",
      "Epoch 17/100\n",
      " 325/7500 [>.............................] - ETA: 45:40 - loss: 0.0025 - acc: 0.9994"
     ]
    }
   ],
   "source": [
    "# The 'accuracy' metric is reported as acc / val_acc in the epoch log,\n",
    "# which are the names the checkpoint and early-stopping callbacks use.\n",
    "model.compile(\n",
    "    optimizer=Adam(lr=1e-5),\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['accuracy'])\n",
    "\n",
    "# len(iterator) is the number of batches in one pass, counting the final\n",
    "# partial batch. Using it instead of `n // batch_size` makes every\n",
    "# validation pass cover the whole validation set rather than dropping the\n",
    "# remainder; for the training set the step count is unchanged when n is a\n",
    "# multiple of batch_size.\n",
    "model.fit_generator(\n",
    "    train_g,\n",
    "    steps_per_epoch=len(train_g),\n",
    "    epochs=100,\n",
    "    callbacks=[ckpt, estop],\n",
    "    validation_data=valid_g,\n",
    "    validation_steps=len(valid_g))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!ls"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
